package org.zjvis.datascience.spark.algorithm;

import org.apache.commons.cli.Option;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.collection.JavaConversions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @description Spark KMeans clustering operator
 * @date 2021-12-23
 */
public class KmeansAlgorithm extends BaseAlgorithm {

    private Long seeds = 1234L;

    private String sourceTable;

    private int k;

    private int maxIter;

    private String disFun;

    private String targetTable;

    private String[] featureCols;

    private String modelPath;


    public KmeansAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    public boolean parseParams(String[] args) {
        super.parseParams(args);
        Option sourceTableOption = new Option("s", "source", true, "source table name");
        sourceTableOption.setRequired(true);
        options.addOption(sourceTableOption);
        Option kOption = new Option("k", "clusterNum", true, "number of cluster");
        options.addOption(kOption);
        Option target = new Option("t", "target", true, "output table");
        target.setRequired(true);
        options.addOption(target);
        Option featureColsOptions = new Option("f", "featureCols", true, "feature column");
        featureColsOptions.setRequired(true);
        options.addOption(featureColsOptions);
        Option maxIterOptions = new Option("m", "maxIter", true, "max iteration");
        options.addOption(maxIterOptions);
        options.addOption(OptionHelper.getSpecOption("dm", "distanceMeasure", "distance measure function", false));

        boolean flag = this.initCommandLine(args);
        if (!flag) {
            logger.error("parser params fail!!!");
            return false;
        }
        if (cmd.hasOption("k") || cmd.hasOption("clusterNum")) {
            k = Integer.parseInt(cmd.getOptionValue("clusterNum"));
        }
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("m") || cmd.hasOption("maxIter")) {
            maxIter = Integer.parseInt(cmd.getOptionValue("maxIter", "20"));
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("dm") || cmd.hasOption("distanceMeasure")) {
            disFun = cmd.getOptionValue("distanceMeasure");
        }

        return true;
    }

    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable);
                } else {
                    dataset = sparkSession.sql(sql).sample(sampleRatio);
                    if (!cacheUtil.cacheTableForDataset(dataset, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = sparkSession.sql(sql);
            }
        } else {
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
                    dataset = cacheRdd.select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol)
//                    .persist(StorageLevel.MEMORY_AND_DISK());
                        .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, idCol)));
            }
        }

        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol("features");
        Dataset<Row> transDF = assembler.transform(dataset).persist(StorageLevel.MEMORY_AND_DISK());
        KMeans kMeans = new KMeans()
                .setK(k)
                .setSeed(seeds)
                .setFeaturesCol("features")
                .setPredictionCol("cluster_id")
                // .setDistanceMeasure(disFun)
                .setMaxIter(maxIter);
        KMeansModel model = kMeans.fit(transDF);
        Dataset<Row> predictionsTmp = model.transform(transDF).drop("features");
        Dataset<Row> predictions = predictionsTmp
                .withColumn("cluster_id_str", predictionsTmp.col("cluster_id").cast(DataTypes.StringType))
                .drop("cluster_id")
                .withColumnRenamed("cluster_id_str", "cluster_id");
        if (isHive) {
            this.saveHiveTable(predictions, targetTable);
        } else {
            UtilTool.saveGreenplumTable(predictions, targetTable);
        }
        modelPath = String.format(UtilTool.MODEL_PATH, algorithmName, uniqueKey);
        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);
        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        inputKV.put("modelPath", modelPath);
        outputTables.add(targetTable);
        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);
        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        transDF.unpersist(false);
        return true;
    }
}
